# Run this script in a fresh R session (e.g. `Rscript script.R` or
# "Restart R" in the IDE) instead of calling rm(list = ls()): clearing the
# global environment does not unload packages or reset options, and it
# silently deletes whatever objects the caller had in their workspace.

library(tidyverse)
library(haven)

# ---- Load & filters ----
d <- read_dta("Data/cep_all.dta") %>%
  mutate(
    # GSE labels: rename the Stata codes, then merge D and E into "Lower"
    gse = as_factor(gse),
    gse = fct_recode(gse, "Upper" = "ABC1", "Upper middle" = "C2", "Middle" = "C3"),
    gse = fct_collapse(gse, "Lower" = c("D", "E")),
    # Zona: rename 'Urbano' to 'Urban'. Rural rows are NOT removed here:
    # fct_drop(only = "Rural") only drops the 'Rural' level when no row uses
    # it, so the Urban share is computed over all respondents.
    zona_u_r = as_factor(zona_u_r) |> fct_recode("Urban" = "Urbano") |> fct_drop(only = "Rural"),
    # Bandwidth filters on the treatment dosage. A missing dosage counts as
    # outside the band (FALSE) rather than NA, because an NA in a logical row
    # index makes `df[filt, ]` return all-NA rows.
    filter2 = !is.na(treatment_dosage_coup) &
      treatment_dosage_coup >= -5 & treatment_dosage_coup <= 5,
    filter  = !is.na(treatment_dosage_coup) &
      treatment_dosage_coup >= -1 & treatment_dosage_coup <= 1
  )

# ---- helpers ----
# Significance stars: "***" for p < .01, "**" for p < .05, "*" for p < .10,
# otherwise "". NA p-values give NA.
stars <- function(p) {
  # findInterval() counts the cut points <= p (0..3), which indexes the band.
  c("***", "**", "*", "")[findInterval(p, c(.01, .05, .10)) + 1L]
}

# Format numbers with exactly one decimal place.
fmt1 <- function(x) {
  sprintf("%.1f", x)
}

# Two-decimal p-value immediately followed by its stars, e.g. "0.03**".
fmtp <- function(p) {
  paste0(sprintf("%.2f", round(p, 2)), stars(p))
}

# Keep non-empty values; NA and "" become "" so they print as empty cells.
cell <- function(x) {
  ifelse(!is.na(x) & x != "", x, "")
}

# N row: number of rows inside `filt` with a non-missing treat1, split by arm
# (treat1 == 0 -> Control, treat1 == 1 -> Treatment) plus the total.
#   df   : data frame with a `treat1` column
#   filt : logical vector, one entry per row of df (NA is treated as FALSE)
# Returns a one-row tibble with character columns Variable..Total.
row_N <- function(df, filt){
  # which() discards NA positions; a plain logical index with NA would add
  # all-NA rows, turning the arm counts into NA and inflating the total.
  x <- df[which(filt & !is.na(df$treat1)), ]
  tibble(
    Variable  = "N",
    Control   = format(sum(x$treat1 == 0), big.mark = ","),
    Treatment = format(sum(x$treat1 == 1), big.mark = ","),
    Pval      = "",
    Total     = format(nrow(x), big.mark = ",")
  )
}

# ---- categorical rows (drop empties) ----
# Categorical balance rows: one row per observed level of `var`, giving the
# percentage of Control (treat1 == 0) and Treatment (treat1 == 1) rows in that
# level, a two-sample proportion-test p-value, and the pooled percentage.
#   df   : data frame with columns `var` and `treat1`
#   var  : column name (string)
#   filt : logical vector, one entry per row of df (NA is treated as FALSE)
# Levels with no rows after filtering are dropped. If an arm is empty, only
# Total is filled. Returns a tibble with character columns Variable..Total
# (zero rows if nothing survives the filter).
balance_cat <- function(df, var, filt){
  empty <- tibble(Variable = character(), Control = character(),
                  Treatment = character(), Pval = character(), Total = character())

  # which() discards NA positions so a missing filter value excludes the row
  # instead of injecting an all-NA row that makes the arm counts NA.
  df <- df[which(filt & !is.na(df[[var]]) & !is.na(df$treat1)), ]
  if (nrow(df) == 0) return(empty)

  df[[var]] <- as_factor(df[[var]]) |> fct_drop()
  df$treat1 <- factor(df$treat1, levels = c(0, 1))
  levs <- levels(df[[var]])
  if (length(levs) == 0) return(empty)

  # Arm sizes do not depend on the level, so compute them once.
  n0 <- sum(df$treat1 == "0")
  n1 <- sum(df$treat1 == "1")

  map_dfr(levs, function(lv){
    in_level  <- df[[var]] == lv
    total_val <- fmt1(100 * mean(in_level, na.rm = TRUE))

    if (n0 == 0 || n1 == 0){
      return(tibble(Variable = lv, Control = "", Treatment = "", Pval = "", Total = total_val))
    }

    a <- sum(in_level & df$treat1 == "0")
    b <- sum(in_level & df$treat1 == "1")
    # prop.test() warns on small expected counts and returns NaN when both
    # proportions are 0 or 1; NaN is reported as p = 1.
    p <- suppressWarnings(prop.test(c(a, b), c(n0, n1))$p.value)
    if (is.na(p)) p <- 1
    tibble(
      Variable  = lv,
      Control   = fmt1(100 * a / n0),
      Treatment = fmt1(100 * b / n1),
      Pval      = fmtp(p),
      Total     = total_val
    )
  })
}

# ---- numeric row (Age, years) ----
# Numeric balance row for `var`: mean in Control (treat1 == 0), mean in
# Treatment (treat1 == 1), the p-value of t.test() (default Welch test), and
# the pooled mean. Only rows where `filt` is TRUE and `var`/treat1 are
# non-missing are used (NA in `filt` is treated as FALSE).
# Returns a one-row tibble labelled `label` with character columns.
balance_num <- function(df, var, filt, label){
  blank_row <- function(total = "") {
    tibble(Variable = label, Control = "", Treatment = "", Pval = "", Total = total)
  }

  # which() discards NA positions so a missing filter value excludes the row
  # instead of injecting an all-NA row that would make the means NA.
  df <- df[which(filt & !is.na(df[[var]]) & !is.na(df$treat1)), ]
  if (nrow(df) == 0) return(blank_row())

  df$treat1 <- factor(df$treat1, levels = c(0, 1))
  x <- df[[var]]
  total_mean <- fmt1(mean(x, na.rm = TRUE))

  # Without both arms there is no comparison to report.
  if (any(table(df$treat1) == 0)) return(blank_row(total_mean))

  m0 <- mean(x[df$treat1 == 0])
  m1 <- mean(x[df$treat1 == 1])
  # t.test() stops with an error (rather than returning NA) when an arm has
  # a single observation or the data are constant; don't abort the table.
  p <- tryCatch(t.test(x ~ df$treat1)$p.value, error = function(e) NA_real_)
  # Undefined p-values are shown as 1, the same convention balance_cat() uses.
  if (is.na(p)) p <- 1

  tibble(
    Variable  = label,
    Control   = fmt1(m0),
    Treatment = fmt1(m1),
    Pval      = fmtp(p),
    Total     = total_mean
  )
}

# ---- build one cohort block ----
# Build the rows for one cohort panel: a bold header, then education, GSE,
# Female, Age and Urban balance rows, then the N row. The logical column
# `sep_after` flags rows that should be followed by a \midrule: the last
# education row, the last GSE row, and the row just before N.
#   ed_order / gse_order : preferred display order of the level labels
#                          (NULL keeps the order balance_cat() returns).
build_block <- function(df, filt, cohort_label,
                        ed_order = NULL, gse_order = c("Upper","Upper middle","Middle","Lower")){
  # Reorder rows by a preferred label order. Labels absent from `lvls` keep
  # their text and go last in their original order (order() is stable and puts
  # NA last). The old factor(..., levels =) approach silently turned them into
  # NA, i.e. unlabeled rows.
  order_rows <- function(rows, lvls) {
    if (is.null(lvls)) return(rows)
    rows[order(match(rows$Variable, lvls)), ]
  }

  head_row <- tibble(Variable = paste0("\\textbf{", cohort_label, "}"),
                     Control = "", Treatment = "", Pval = "", Total = "")

  ed  <- balance_cat(df, "ed_levels", filt) |> order_rows(ed_order)
  gse <- balance_cat(df, "gse", filt) |> order_rows(gse_order)

  # Female share only; label matched in English or Spanish, case-insensitively.
  # filter() keeps the column layout even when no row matches.
  sx <- balance_cat(df, "sexo", filt) %>%
    filter(str_detect(Variable, regex("^Female|^Mujer", ignore_case = TRUE))) %>%
    mutate(Variable = "Female")

  age   <- balance_num(df, "edad", filt, "Age")
  urb   <- balance_cat(df, "zona_u_r", filt) %>% filter(Variable == "Urban")
  nrowt <- row_N(df, filt)

  block <- bind_rows(head_row, ed, gse, sx, age, urb, nrowt)

  block$sep_after <- FALSE
  if (nrow(ed) > 0)  block$sep_after[1 + nrow(ed)] <- TRUE
  if (nrow(gse) > 0) block$sep_after[1 + nrow(ed) + nrow(gse)] <- TRUE
  # Separator before N on whichever row precedes it (the Age row always
  # exists, so this is never the header), even if the Urban row is missing.
  block$sep_after[nrow(block) - 1] <- TRUE

  block
}

# Preferred display order for the education rows.
ed_order <- c("High School", "Less than College", "College or more")

# Two cohort panels, stacked: the wide band (filter2: dosage within [-5, 5])
# followed by the narrow band (filter: dosage within [-1, 1]).
b1 <- build_block(d, filt = d$filter2, cohort_label = "Cohorts: 50–60", ed_order = ed_order)
b2 <- build_block(d, filt = d$filter, cohort_label = "Cohorts: 54 and 56", ed_order = ed_order)
tbl <- bind_rows(list(b1, b2))

# ---- Emit RAW LaTeX matching your format (fixed-width Total col) ----
latex_lines <- c(
  "\\begin{table}[!htbp]",
  "\\centering",
  "\\caption{Descriptive Statistics}",
  "\\footnotesize",
  "\\label{desc}",
  # NOTE(review): `C{1.1cm}` is not a built-in column type; it must be defined
  # in the including document's preamble (e.g. via \newcolumntype) - confirm.
  "\\begin{tabular}{@{}lcccC{1.1cm}@{}}",
  "\\toprule",
  " & Control & Treatment & p-value diff & Total\\\\",
  "\\midrule"
)

# One LaTeX row per table row; rows flagged in `sep_after` get a \midrule
# right after them (missing/NA flags mean no rule).
row_text <- sprintf("%s & %s & %s & %s & %s\\\\",
                    cell(tbl$Variable), cell(tbl$Control), cell(tbl$Treatment),
                    cell(tbl$Pval), cell(tbl$Total))
sep_flags <- if ("sep_after" %in% names(tbl)) tbl$sep_after %in% TRUE else rep(FALSE, nrow(tbl))
body_lines <- unlist(lapply(seq_len(nrow(tbl)), function(i) {
  c(row_text[i], if (sep_flags[i]) "\\midrule")
}))

# Footer note. "<" is set in math mode: in text mode under LaTeX's default
# OT1 font encoding it prints as an inverted exclamation mark.
latex_lines <- c(latex_lines,
                 body_lines,
                 "\\bottomrule",
                 "\\multicolumn{5}{l}{*$p<.1$; **$p<.05$; ***$p<.01$. Differences are t-tests (means) or two-sample tests of proportions.}\\\\",
                 "\\end{tabular}",
                 "\\end{table}"
)

# Write to file (creating the output folder if needed) and echo to console.
dir.create("Output", showWarnings = FALSE, recursive = TRUE)
writeLines(latex_lines, "Output/desc_stats.tex")
cat(paste(latex_lines, collapse = "\n"))



